{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Uni-Mol Pocket Representation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Licenses**\n",
    "\n",
    "Copyright (c) DP Technology.\n",
    "\n",
    "This source code is licensed under the MIT license found in the\n",
    "LICENSE file in the root directory of this source tree.\n",
    "\n",
    "**Citations**\n",
    "\n",
    "Please cite the following papers if you use this notebook:\n",
    "\n",
    "- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n",
    "ChemRxiv (2022)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the pretrained pocket weights and the CASF-2016 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Stop at the first failing command so a broken download is not silently ignored.\n",
    "set -euo pipefail\n",
    "\n",
    "pocket_data_url='https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/CASF-2016.tar.gz'\n",
    "pocket_weight_url='https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/pocket_pre_220816.pt'\n",
    "\n",
    "# -O pins the output name, so re-running the cell overwrites the file instead of\n",
    "# leaving numbered duplicates next to a stale archive.\n",
    "wget -q -O \"CASF-2016.tar.gz\" \"${pocket_data_url}\"\n",
    "tar -xzf \"CASF-2016.tar.gz\"\n",
    "wget -q -O \"pocket_pre_220816.pt\" \"${pocket_weight_url}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read pocket information from CASF-2016 and save it to a .lmdb file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import lmdb\n",
    "from biopandas.pdb import PandasPdb\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import re\n",
    "import json\n",
    "import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the casf2016/ folder and casf2016.pocket.json read below.\n",
    "CASF_PATH = \"CASF-2016\"\n",
    "# Atom names that parser() labels with side flag 0; every other name gets 1.\n",
    "main_atoms = [\"N\", \"CA\", \"C\", \"O\", \"H\"]\n",
    "\n",
    "def load_from_CASF(pdb_id):\n",
    "    \"\"\"Load the protein structure and pocket residues of one CASF-2016 entry.\n",
    "\n",
    "    Args:\n",
    "        pdb_id: Entry identifier. It selects the file pdb_id + \"_protein.pdb\"\n",
    "            inside CASF_PATH/casf2016 and the key of the same name in\n",
    "            CASF_PATH/casf2016.pocket.json.\n",
    "\n",
    "    Returns:\n",
    "        A (pmol, pocket_residues) tuple: the PandasPdb object read from the\n",
    "        PDB file and the value stored under pdb_id in the pocket JSON file.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If either file cannot be read or parsed, or the JSON\n",
    "            file has no entry for pdb_id.\n",
    "    \"\"\"\n",
    "    pdb_path = os.path.join(CASF_PATH, \"casf2016\", pdb_id + \"_protein.pdb\")\n",
    "    pocket_json_path = os.path.join(CASF_PATH, \"casf2016.pocket.json\")\n",
    "    try:\n",
    "        pmol = PandasPdb().read_pdb(pdb_path)\n",
    "        # The context manager closes the JSON file handle deterministically.\n",
    "        with open(pocket_json_path) as pocket_file:\n",
    "            pocket_residues = json.load(pocket_file)[pdb_id]\n",
    "    except (OSError, KeyError, ValueError) as err:\n",
    "        # Fail here with context instead of returning None, which the caller\n",
    "        # cannot unpack and which hid the real cause of the failure.\n",
    "        raise RuntimeError(\n",
    "            f\"Could not load CASF-2016 entry {pdb_id!r}: check {pdb_path!r} \"\n",
    "            f\"and the {pdb_id!r} key in {pocket_json_path!r}.\"\n",
    "        ) from err\n",
    "    return pmol, pocket_residues\n",
    "\n",
    "def normalize_atoms(atom):\n",
    "    \"\"\"Return atom with every run of decimal digits removed.\n",
    "\n",
    "    For example \"HD21\" becomes \"HD\"; a name made only of digits becomes \"\".\n",
    "    \"\"\"\n",
    "    # Raw string: a bare \"\\d\" in a normal literal is an invalid escape sequence\n",
    "    # that newer Python versions warn about.\n",
    "    return re.sub(r\"\\d+\", \"\", atom)\n",
    "\n",
    "def parser(pdb_id):\n",
    "    \"\"\"Build the pickled pocket record for one CASF-2016 entry.\n",
    "\n",
    "    ATOM and HETATM rows whose chain_id + residue_number key appears in the\n",
    "    entry's pocket residues are collected; rows whose atom name consists\n",
    "    solely of digits are dropped.\n",
    "\n",
    "    Args:\n",
    "        pdb_id: Entry identifier passed on to load_from_CASF.\n",
    "\n",
    "    Returns:\n",
    "        The pickle.dumps bytes (protocol=-1) of a dict with the keys\n",
    "        \"atoms\", \"coordinates\", \"side\", \"residue\" and \"pdbid\".\n",
    "    \"\"\"\n",
    "    structure, pocket_ids = load_from_CASF(pdb_id)\n",
    "\n",
    "    # Key every row by chain + residue number, then keep the pocket rows only.\n",
    "    selected = []\n",
    "    for record_type in (\"ATOM\", \"HETATM\"):\n",
    "        records = structure.df[record_type]\n",
    "        records[\"ID\"] = records[\"chain_id\"].astype(str) + records[\n",
    "            \"residue_number\"\n",
    "        ].astype(str)\n",
    "        selected.append(records[records[\"ID\"].isin(pocket_ids)])\n",
    "    pocket = pd.concat(selected, axis=0, ignore_index=True)\n",
    "\n",
    "    # Rows whose digit-stripped atom name is empty are discarded.\n",
    "    pocket[\"normalize_atom\"] = pocket[\"atom_name\"].map(normalize_atoms)\n",
    "    pocket = pocket[pocket[\"normalize_atom\"] != \"\"]\n",
    "\n",
    "    atom_names = pocket[\"atom_name\"].apply(normalize_atoms).values.tolist()\n",
    "    xyz = pocket[[\"x_coord\", \"y_coord\", \"z_coord\"]].values\n",
    "    # Flag is 0 for names listed in main_atoms and 1 for everything else.\n",
    "    side_flags = [int(name not in main_atoms) for name in atom_names]\n",
    "    residue_ids = (\n",
    "        pocket[\"chain_id\"].astype(str) + pocket[\"residue_number\"].astype(str)\n",
    "    ).values.tolist()\n",
    "\n",
    "    record = {\n",
    "        \"atoms\": atom_names,\n",
    "        \"coordinates\": [xyz],  # the coordinate array is stored inside a one-element list\n",
    "        \"side\": side_flags,\n",
    "        \"residue\": residue_ids,\n",
    "        \"pdbid\": pdb_id,\n",
    "    }\n",
    "    return pickle.dumps(record, protocol=-1)\n",
    "\n",
    "def write_lmdb(pdb_id_list, job_name, outpath=\"./results\"):\n",
    "    \"\"\"Write one pickled pocket record per PDB id into an LMDB file.\n",
    "\n",
    "    The file is outpath/job_name + \".lmdb\"; an existing file of that name is\n",
    "    replaced. Each record is stored under the ASCII-encoded position of its\n",
    "    id in pdb_id_list (\"0\", \"1\", ...).\n",
    "\n",
    "    Args:\n",
    "        pdb_id_list: PDB ids handed to parser one by one.\n",
    "        job_name: Base name of the output file.\n",
    "        outpath: Output directory; created when missing.\n",
    "    \"\"\"\n",
    "    os.makedirs(outpath, exist_ok=True)\n",
    "    outputfilename = os.path.join(outpath, job_name + \".lmdb\")\n",
    "    # Start from a clean file; a missing file is the only expected failure here.\n",
    "    try:\n",
    "        os.remove(outputfilename)\n",
    "    except FileNotFoundError:\n",
    "        pass\n",
    "    env_new = lmdb.open(\n",
    "        outputfilename,\n",
    "        subdir=False,\n",
    "        readonly=False,\n",
    "        lock=False,\n",
    "        readahead=False,\n",
    "        meminit=False,\n",
    "        max_readers=1,\n",
    "        map_size=int(10e9),\n",
    "    )\n",
    "    try:\n",
    "        # The transaction context commits on success and aborts on an exception.\n",
    "        with env_new.begin(write=True) as txn_write:\n",
    "            # Wrap the id list itself (not the enumerate iterator) so tqdm can\n",
    "            # read its length and show real progress.\n",
    "            for i, pdb_id in enumerate(tqdm(pdb_id_list)):\n",
    "                txn_write.put(f\"{i}\".encode(\"ascii\"), parser(pdb_id))\n",
    "    finally:\n",
    "        # Release the environment even when parsing an entry fails.\n",
    "        env_new.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_name = 'get_pocket_repr'   # replace with your custom name\n",
    "data_path = './results'  # replace with your data path\n",
    "weight_path='pocket_pre_220816.pt'  # replace with your checkpoint path\n",
    "only_polar=0  # no h\n",
    "dict_name='dict_coarse.txt'\n",
    "batch_size=16\n",
    "results_path=data_path   # replace with your save path\n",
    "\n",
    "# Collect the ids (first four characters of the file name) of every entry that has\n",
    "# a *_protein.pdb file, the only file load_from_CASF reads. Sorting keeps the\n",
    "# LMDB key -> id mapping identical between runs; iterating a set gave an arbitrary order.\n",
    "excluded_ids = {'3qgy'}  # entry this notebook leaves out\n",
    "casf_collect = sorted(\n",
    "    {\n",
    "        item[:4]\n",
    "        for item in os.listdir(os.path.join(CASF_PATH, \"casf2016\"))\n",
    "        if item.endswith(\"_protein.pdb\")\n",
    "    }\n",
    "    - excluded_ids\n",
    ")\n",
    "write_lmdb(casf_collect, job_name=job_name, outpath=data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run inference with the pretrained pocket checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: Inference currently supports a single GPU only; CUDA_VISIBLE_DEVICES=\"0\" in the command below pins the run to GPU 0.\n",
    "# The interpolated paths are quoted so that locations containing spaces still work.\n",
    "!cp \"../example_data/pocket/$dict_name\" \"$data_path\"\n",
    "!CUDA_VISIBLE_DEVICES=\"0\" python ../unimol/infer.py --user-dir ../unimol \"$data_path\" --valid-subset \"$job_name\" \\\n",
    "       --results-path \"$results_path\" \\\n",
    "       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
    "       --task unimol_pocket --loss unimol_infer --arch unimol_base \\\n",
    "       --path \"$weight_path\" \\\n",
    "       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
    "       --dict-name \"$dict_name\" \\\n",
    "       --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read .pkl and save results to .csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_csv_results(predict_path, results_path):\n",
    "    \"\"\"Flatten the pickled inference output into results_path/mol_repr.csv.\n",
    "\n",
    "    Args:\n",
    "        predict_path: Path of the pickle to load. Its content is iterated as\n",
    "            batches, each indexed by the keys \"bsz\", \"data_name\",\n",
    "            \"mol_repr_cls\" and \"pair_repr\".\n",
    "        results_path: Directory that receives mol_repr.csv.\n",
    "    \"\"\"\n",
    "    # Unpickling executes code from the file: only load pickles you produced yourself.\n",
    "    predict = pd.read_pickle(predict_path)\n",
    "    pdb_id_list, mol_repr_list, pair_repr_list = [], [], []\n",
    "    for batch in predict:\n",
    "        # Only the first \"bsz\" entries of each per-batch field are read.\n",
    "        for i in range(batch[\"bsz\"]):\n",
    "            pdb_id_list.append(batch[\"data_name\"][i])\n",
    "            mol_repr_list.append(batch[\"mol_repr_cls\"][i])\n",
    "            pair_repr_list.append(batch[\"pair_repr\"][i])\n",
    "    predict_df = pd.DataFrame({\"pdb_id\": pdb_id_list, \"mol_repr\": mol_repr_list, \"pair_repr\": pair_repr_list})\n",
    "    # info() prints its report itself and returns None, so it must not be passed\n",
    "    # to print(), which would emit a stray \"None\".\n",
    "    print(predict_df.head(1))\n",
    "    predict_df.info()\n",
    "    # NOTE(review): to_csv writes the per-row representations as text; confirm that\n",
    "    # large arrays survive this round trip, otherwise prefer a binary format.\n",
    "    predict_df.to_csv(os.path.join(results_path, 'mol_repr.csv'), index=False)\n",
    "\n",
    "# Locate the pickle written by the inference step. Sorting makes the choice\n",
    "# deterministic when several files match; an empty match gets a clear error\n",
    "# instead of a bare IndexError.\n",
    "pkl_candidates = sorted(glob.glob(f'{results_path}/*_{job_name}.out.pkl'))\n",
    "if not pkl_candidates:\n",
    "    raise FileNotFoundError(\n",
    "        f'No file matching *_{job_name}.out.pkl found in {results_path!r}; run the inference cell first.'\n",
    "    )\n",
    "pkl_path = pkl_candidates[0]\n",
    "get_csv_results(pkl_path, results_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
